#include "AVLTree.hpp"
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <cmath>

/**
 * Allocates a detached tree node that refers to `data`.
 *
 * The string is not copied: the node stores the pointer it is given.
 * Both child links start out null and the height starts at 0.
 *
 * @param data string the node should point at
 * @return the new node, or nullptr if the allocation failed
 */
TreeNode *buildNode(const char *data)
{
    TreeNode *newNode = static_cast<TreeNode *>(malloc(sizeof(TreeNode)));
    if (newNode == nullptr)
    {
        // Out of memory: hand the failure to the caller rather than
        // writing through a null pointer below.
        return nullptr;
    }
    newNode->data = data;
    newNode->left = nullptr;
    newNode->right = nullptr;
    newNode->height = 0;
    return newNode;
}

/**
 * Orders two nodes by their string payload.
 *
 * @return the strcmp() result for n1's data against n2's data:
 *         negative, zero or positive.
 */
int compare(TreeNode *n1, TreeNode *n2)
{
    const char *lhs = n1->data;
    const char *rhs = n2->data;
    return strcmp(lhs, rhs);
}
/**
 * Returns the larger of height(n1) and height(n2).
 * Either argument may be null, because height() accepts null.
 */
int AVLTree::maxHeight(TreeNode *n1, TreeNode *n2)
{
    const int first = height(n1);
    const int second = height(n2);
    if (first > second)
    {
        return first;
    }
    return second;
}

/**
 * Creates a tree whose root is a single node built from `data`.
 */
AVLTree::AVLTree(const char *data)
{
    TreeNode *initial = buildNode(data);
    this->root = initial;
}

/**
 *  1
 *  Left-left case: the unbalanced node (9) is rotated with its left child (8).
 *
 *      6                  6
 *     /  \               / \
 *    5    9   ----->    5   8
 *   /    /             /   / \
 *   3   8             3   7   9
 *      /
 *     7
 *
 */
TreeNode *AVLTree::singleRotateLeft(TreeNode *node)
{
    // Rotates `node` with its left child and returns that child, which is
    // the new root of this subtree. node->left must not be null.
    printf("singleRotateLeft(node(%s))\n", node->data);
    TreeNode *left = node->left;
    // `node` adopts the promoted child's right subtree; those keys lie
    // between `left` and `node`, so they belong on node's left side.
    node->left = left->right;
    left->right = node;
    // `node` now sits below `left`, so its height has to be refreshed first.
    node->height = maxHeight(node->left, node->right) + 1;
    left->height = maxHeight(left->left, left->right) + 1;
    return left;
}

/**
 * 4
 *      8                  8
 *     /  \               / \
 *    6    9   ----->    6   12
 *   /      \           /   / \
 *  3        12        3   9   20
 *            \
 *             20
 *
 */
TreeNode *AVLTree::singleRotateRight(TreeNode *node)
{
    // Rotates `node` with its right child and returns that child, which is
    // the new root of this subtree. node->right must not be null.
    printf("singleRotateRight(node(%s))\n", node->data);
    TreeNode *right = node->right;
    // `node` adopts the promoted child's left subtree; those keys lie
    // between `node` and `right`, so they belong on node's right side.
    node->right = right->left;
    // The old subtree root becomes the left child of the promoted node.
    right->left = node;
    // `node` now sits below `right`, so its height has to be refreshed first.
    node->height = maxHeight(node->left, node->right) + 1;
    right->height = maxHeight(right->left, right->right) + 1;
    return right;
}

/**
 * 2
 *      8                  8                 6
 *     /  \               / \               / \
 *    5    9   ----->    6   9  ---->      5   8
 *   / \                / \               /   / \
 *  3   6              5   7             3   7   9
 *       \            /
 *        7          3
 *
 */
TreeNode *AVLTree::doubleRotateLeft(TreeNode *node)
{
    printf("doubleRotateLeft(node(%s))\n", node->data);
    // Step 1: rotate the left child with its own right child.
    TreeNode *rotatedChild = singleRotateRight(node->left);
    node->left = rotatedChild;
    // Step 2: rotate `node` with its new left child; the result is the
    // root of the rebalanced subtree.
    TreeNode *subtreeRoot = singleRotateLeft(node);
    return subtreeRoot;
}

/**
 * 3
 *      4                  4                 6
 *     /  \               / \               / \
 *    2    7   ----->    2   6  ---->      4   7
 *        / \               / \           / \   \
 *       6   8             5   7         2   5   8
 *      /                       \
 *     5                         8
 *
 */
TreeNode *AVLTree::doubleRotateRight(TreeNode *node)
{
    printf("doubleRotateRight(node(%s))\n", node->data);
    // Step 1: rotate the right child with its own left child.
    TreeNode *rotatedChild = singleRotateLeft(node->right);
    node->right = rotatedChild;
    // Step 2: rotate `node` with its new right child; the result is the
    // root of the rebalanced subtree.
    TreeNode *subtreeRoot = singleRotateRight(node);
    return subtreeRoot;
}

/**
 * Returns the height stored in `node`, or -1 when `node` is null.
 */
int AVLTree::height(TreeNode *node)
{
    return node != nullptr ? node->height : -1;
}

/**
 * Inserts `data` into the tree and rebalances it.
 *
 * The string is not copied; the new node keeps the pointer. A key that
 * compares equal (strcmp) to one already stored is ignored and no node is
 * allocated for it.
 *
 * @param data key to insert
 * @return the root of the tree after the call; `root` is updated because a
 *         rotation (or the first insertion into an empty tree) can change it.
 *         The tree is left untouched if the node cannot be allocated.
 */
TreeNode *AVLTree::insert(const char *data)
{
    // Look the key up first, so a duplicate never allocates a node that
    // would end up linked nowhere.
    for (TreeNode *cur = root; cur != nullptr;)
    {
        const int order = strcmp(cur->data, data);
        if (order == 0)
        {
            return root;
        }
        cur = order > 0 ? cur->left : cur->right;
    }

    TreeNode *newNode = buildNode(data);
    if (newNode == nullptr)
    {
        return root;
    }
    // The recursive insert returns the (possibly new) subtree root.
    root = insert(root, newNode);
    return root;
}

/**
 * Inserts `newNode` into the subtree rooted at `parent`, rebalancing on the
 * way back up.
 *
 * If a node with an equal key already exists, `newNode` is not linked and
 * remains owned by the caller.
 *
 * @param parent  root of the subtree to insert into (may be null)
 * @param newNode detached node to link in
 * @return the root of the subtree after insertion and any rotation
 */
TreeNode *AVLTree::insert(TreeNode *parent, TreeNode *newNode)
{
    if (parent == nullptr)
    {
        // Free slot found: the new node becomes this subtree.
        newNode->height = maxHeight(newNode->left, newNode->right) + 1;
        return newNode;
    }
    int cmpr = compare(parent, newNode);
    if (cmpr > 0)
    {
        // parent > newNode: descend to the left.
        parent->left = insert(parent->left, newNode);
        if (height(parent->left) - height(parent->right) == 2)
        {
            // Outside (left-left) insertion needs one rotation, inside
            // (left-right) needs two. Which one it is depends on where the
            // key went relative to the left child, not relative to parent.
            if (compare(newNode, parent->left) < 0)
                parent = singleRotateLeft(parent);
            else
                parent = doubleRotateLeft(parent);
        }
    }
    else if (cmpr < 0)
    {
        // parent < newNode: descend to the right.
        parent->right = insert(parent->right, newNode);
        if (height(parent->right) - height(parent->left) == 2)
        {
            // Mirror image: right-right is a single rotation, right-left a
            // double one.
            if (compare(newNode, parent->right) > 0)
                parent = singleRotateRight(parent);
            else
                parent = doubleRotateRight(parent);
        }
    }
    // cmpr == 0: equal key, nothing is linked.
    parent->height = maxHeight(parent->left, parent->right) + 1;
    return parent;
}
/**
 * Checks the balance condition for the subtree rooted at `node`.
 *
 * Child links are walked directly; the height stored in the nodes is not
 * consulted, and key ordering is not checked.
 *
 * @return -1 if some node's two subtrees differ in depth by more than one;
 *         otherwise the number of levels in the subtree (0 when it is empty).
 */
int AVLTree::isAVL(TreeNode *node)
{
    if (node == nullptr)
    {
        return 0;
    }
    const int left = isAVL(node->left);
    if (left == -1)
        return -1;
    const int right = isAVL(node->right);
    if (right == -1)
        return -1;
    if (abs(left - right) > 1)
        return -1;
    // Plain integer maximum: no detour through float as fmaxf would take.
    return (left > right ? left : right) + 1;
}

int AVLTree::isAVL()
{
    return isAVL(root);
}

/**
 * Removes every node that has exactly one child from the subtree rooted at
 * `node`, putting that child in its place. Removed nodes are released with
 * free(). The stored height of each surviving node is recomputed.
 *
 * @param node root of the subtree to process (may be null)
 * @return the root of the subtree after removal
 */
TreeNode *AVLTree::removeHalfNodes(TreeNode *node)
{
    if (node == nullptr)
    {
        return nullptr;
    }
    // Children first, so the test below sees their final shape.
    node->left = removeHalfNodes(node->left);
    node->right = removeHalfNodes(node->right);

    const bool hasLeft = node->left != nullptr;
    const bool hasRight = node->right != nullptr;
    if (hasLeft != hasRight)
    {
        // Exactly one child: splice this node out in favour of it.
        TreeNode *onlyChild = hasLeft ? node->left : node->right;
        free(node);
        return onlyChild;
    }
    // Leaf or two-child node stays; its subtrees may have become shallower,
    // so the cached height must be refreshed.
    node->height = maxHeight(node->left, node->right) + 1;
    return node;
}
/**
 * Removes the current leaves of the subtree rooted at `node`. A node that
 * only becomes a leaf because of this pass is kept. Removed nodes are
 * released with free(). The stored height of each surviving node is
 * recomputed.
 *
 * @param node root of the subtree to process (may be null)
 * @return the root of the subtree after removal, or nullptr if `node` was
 *         null or was itself a leaf
 */
TreeNode *AVLTree::removeLeaves(TreeNode *node)
{
    if (node == nullptr)
    {
        return nullptr;
    }
    if (node->left == nullptr && node->right == nullptr)
    {
        free(node);
        return nullptr;
    }
    node->left = removeLeaves(node->left);
    node->right = removeLeaves(node->right);
    // Children may have disappeared, so refresh the cached height.
    node->height = maxHeight(node->left, node->right) + 1;
    return node;
}

/**
 * Removes all one-child nodes from the whole tree.
 *
 * `root` is reassigned because the old root can itself be a one-child node,
 * in which case it is freed and replaced by its child.
 *
 * @return the root of the tree after removal
 */
TreeNode *AVLTree::removeHalfNodes()
{
    root = removeHalfNodes(root);
    return root;
}

/**
 * Removes the current leaves from the whole tree.
 *
 * `root` is reassigned because a single-node tree has a leaf as its root;
 * that node is freed and the tree becomes empty (root == nullptr).
 *
 * @return the root of the tree after removal
 */
TreeNode *AVLTree::removeLeaves()
{
    root = removeLeaves(root);
    return root;
}
